{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1f83f273",
   "metadata": {},
   "source": [
    "# SageMaker\n",
    "\n",
    "Let's load the `SagemakerEndpointEmbeddings` class. It can be used if you host your own model, e.g. a Hugging Face model, on SageMaker.\n",
    "\n",
    "For instructions on how to do this, see [this guide to deploying a custom Hugging Face inference script on SageMaker](https://www.philschmid.de/custom-inference-huggingface-sagemaker).\n",
    "\n",
    "**Note**: To handle batched requests, you need to adjust the return line of the `predict_fn()` function in the custom `inference.py` script so that it returns one vector per input rather than only the first:\n",
    "\n",
    "Change from:\n",
    "\n",
    "`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
    "\n",
    "to:\n",
    "\n",
    "`return {\"vectors\": sentence_embeddings.tolist()}`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d366bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the running kernel's environment; a bare !pip3 may target a different interpreter.\n",
    "# NOTE(review): consider pinning versions - the `langchain.embeddings` import path used below may move in newer releases.\n",
    "%pip install langchain boto3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1e9b926a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Dict, List\n",
    "\n",
    "from langchain.embeddings import SagemakerEndpointEmbeddings\n",
    "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
    "\n",
    "\n",
    "class ContentHandler(EmbeddingsContentHandler):\n",
    "    \"\"\"Serializes requests to, and parses responses from, a custom SageMaker embeddings endpoint.\n",
    "\n",
    "    Both directions use JSON; adapt the two methods below if your `inference.py`\n",
    "    expects or returns a different payload format.\n",
    "    \"\"\"\n",
    "\n",
    "    # MIME types for the request and response bodies (both JSON, see the methods below).\n",
    "    content_type = \"application/json\"\n",
    "    accepts = \"application/json\"\n",
    "\n",
    "    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:\n",
    "        \"\"\"\n",
    "        Transforms the input into bytes that can be consumed by the SageMaker endpoint.\n",
    "        Args:\n",
    "            inputs: List of input strings; all of them are sent in a single request.\n",
    "            model_kwargs: Additional keyword arguments merged into the JSON payload\n",
    "                next to \"inputs\". None is treated as no extra arguments.\n",
    "        Returns:\n",
    "            The UTF-8 encoded JSON request body.\n",
    "        \"\"\"\n",
    "        # Example: inference.py expects a JSON string with an \"inputs\" key:\n",
    "        input_str = json.dumps({\"inputs\": inputs, **(model_kwargs or {})})\n",
    "        return input_str.encode(\"utf-8\")\n",
    "\n",
    "    def transform_output(self, output: bytes) -> List[List[float]]:\n",
    "        \"\"\"\n",
    "        Transforms the endpoint's response body into a list of embeddings.\n",
    "        Args:\n",
    "            output: The response body from the SageMaker endpoint. Despite the\n",
    "                `bytes` annotation, it must be a file-like object with a `.read()`\n",
    "                method (presumably the streaming `Body` of the `invoke_endpoint`\n",
    "                response) whose content is UTF-8 encoded JSON.\n",
    "        Returns:\n",
    "            The transformed output - list of embeddings.\n",
    "        Raises:\n",
    "            KeyError: If the response JSON has no \"vectors\" key.\n",
    "            ValueError: If \"vectors\" is a single flat vector instead of one\n",
    "                vector per input (i.e. inference.py was not adjusted for batching).\n",
    "        Note:\n",
    "            The length of the outer list is the number of input strings.\n",
    "            The length of the inner lists is the embedding dimension.\n",
    "        \"\"\"\n",
    "        # Example: inference.py returns a JSON string with the list of\n",
    "        # embeddings in a \"vectors\" key:\n",
    "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
    "        vectors = response_json[\"vectors\"]\n",
    "        # An un-adjusted inference.py returns only the first embedding as a flat\n",
    "        # list of floats; fail loudly instead of handing callers numbers where\n",
    "        # they expect one vector per input.\n",
    "        if vectors and not isinstance(vectors[0], list):\n",
    "            raise ValueError(\n",
    "                \"Expected one embedding vector per input under 'vectors', got a flat list; \"\n",
    "                \"make predict_fn() in inference.py return sentence_embeddings.tolist().\"\n",
    "            )\n",
    "        return vectors\n",
    "\n",
    "\n",
    "content_handler = ContentHandler()\n",
    "\n",
    "# Replace with the name and region of your own deployed endpoint.\n",
    "ENDPOINT_NAME = \"huggingface-pytorch-inference-2023-03-21-16-14-03-834\"\n",
    "REGION_NAME = \"us-east-1\"\n",
    "\n",
    "embeddings = SagemakerEndpointEmbeddings(\n",
    "    # credentials_profile_name=\"credentials-profile-name\",\n",
    "    endpoint_name=ENDPOINT_NAME,\n",
    "    region_name=REGION_NAME,\n",
    "    content_handler=content_handler,\n",
    ")\n",
    "\n",
    "\n",
    "# Alternatively, pass a pre-configured boto3 client instead of region_name.\n",
    "# Endpoints are regional, so the client must use the endpoint's own region.\n",
    "# import boto3\n",
    "#\n",
    "# client = boto3.client(\n",
    "#     \"sagemaker-runtime\",\n",
    "#     region_name=REGION_NAME,\n",
    "# )\n",
    "# embeddings = SagemakerEndpointEmbeddings(\n",
    "#     endpoint_name=ENDPOINT_NAME,\n",
    "#     client=client,\n",
    "#     content_handler=content_handler,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe9797b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_text = \"foo\"\n",
    "query_result = embeddings.embed_query(query_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "76f1b752",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use several documents so a mis-configured (non-batched) inference.py is easy to spot.\n",
    "documents = [\"foo\", \"bar\"]\n",
    "doc_results = embeddings.embed_documents(documents)\n",
    "# Expect exactly one embedding vector per input document.\n",
    "assert len(doc_results) == len(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff99b21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show (number of documents, embedding dimension) rather than dumping every float.\n",
    "len(doc_results), len(doc_results[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaad49f8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
